Molecular Dynamics in Dex
This is more or less a port of JAX MD to Dex, written to see how molecular
dynamics looks in Dex. For now the two implementations share nearly the same
structure, though the details differ.
def truncate(x:Float) -> Float =
  case x < 0.0 of
    True -> -floor (-x)
    False -> floor x

def pmod(x:Float, y:Float) -> Float =
  x - floor (x / y) * y

def fmod(x:Float, y:Float) -> Float =
  x - truncate (x / y) * y

def Vec(dim|Ix) -> Type = dim => Float

def sq_norm(r:Vec dim) -> Float given (dim|Ix) =
  sum $ for i. r[i] * r[i]

def norm(r:Vec dim) -> Float given (dim|Ix) =
  sqrt $ sq_norm r
This computes a size for a box of the given number of dimensions,
such that the given number of particles will fill it with the given
density.
def box_size_at_number_density(particle_count:Int, density:Float, dim:Int) -> Float =
  pow ((i_to_f particle_count) / density) (1.0 / (i_to_f dim))
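As a quick sanity check of the formula, here is a minimal Python sketch (not part of the Dex program) that mirrors box_size_at_number_density. The inputs 500, 1.2 and 2 are illustrative values chosen for this sketch.

```python
def box_size_at_number_density(particle_count: int, density: float, dim: int) -> float:
    """Side length L of a box such that particle_count / L**dim == density."""
    return (particle_count / density) ** (1.0 / dim)


side = box_size_at_number_density(500, 1.2, 2)
print(round(side, 3))             # 20.412
print(round(500 / side ** 2, 6))  # recovers the density: 1.2
```

Inverting the result (particles divided by box volume) gives back the requested density, which is the property the function is meant to guarantee.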
def Displacement(dim|Ix) -> Type = (Vec dim, Vec dim) -> Vec dim
def Shift(dim|Ix) -> Type = (Vec dim, Vec dim) -> Vec dim

def free_displacement() ->> Displacement dim given (dim|Ix) =
  \r_1 r_2. r_1 - r_2

def free_shift() ->> Shift dim given (dim|Ix) =
  \r dr. r + dr

def periodic_wrap(box:Float, dr:dim=>Float) -> (dim=>Float) given (dim|Ix) =
  for i. (pmod (dr[i] + box / 2.0) box) - box / 2.0

def periodic_displacement(box:Float) -> Displacement dim given (dim|Ix) =
  \r_1 r_2. periodic_wrap box (r_1 - r_2)

def periodic_shift(box:Float) -> Shift dim given (dim|Ix) =
  \r dr. for i. pmod (r[i] + dr[i]) box
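To see what the periodic variants do, the following Python sketch re-implements pmod, fmod and periodic_displacement on plain lists. It is only an illustration under assumed inputs (a box of side 10 and hand-picked points); the name fmod_trunc is chosen here to avoid clashing with Python's own math.fmod.

```python
import math


def pmod(x, y):
    # Floor-based modulus: the result always lies in [0, y) for y > 0.
    return x - math.floor(x / y) * y


def fmod_trunc(x, y):
    # Truncation-based modulus: the result keeps the sign of x.
    return x - math.trunc(x / y) * y


def periodic_wrap(box, dr):
    # Map each component into [-box/2, box/2).
    return [pmod(c + box / 2.0, box) - box / 2.0 for c in dr]


def periodic_displacement(box, r1, r2):
    return periodic_wrap(box, [a - b for a, b in zip(r1, r2)])


print(pmod(-1.0, 10.0))        # 9.0
print(fmod_trunc(-1.0, 10.0))  # -1.0
d = periodic_displacement(10.0, [9.5, 1.0], [0.5, 1.0])
print(d)                       # [-1.0, 0.0]
```

The two points are 9 apart inside the box but only 1 apart across the boundary, and the wrapped displacement reports the shorter of the two.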
def Energy(n|Ix, dim|Ix) -> Type = (n => Vec dim) -> Float
The "soft-sphere" energy, as a function of the distance r between two
particles. It is zero once r reaches the diameter σ.
def soft_sphere(ε:Float, α:Float, σ:Float, r:Float) -> Float =
  case r < σ of
    True -> ε / α * pow (1 - r / σ) α
    False -> 0.0

def harmonic_soft_sphere(sigma:Float) -> (Float) -> Float = \r.
  soft_sphere 1.0 2.0 sigma r
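A few sample values make the shape of this potential concrete. The Python sketch below mirrors the two definitions above; the sample distances are arbitrary choices for illustration.

```python
def soft_sphere(eps, alpha, sigma, r):
    # Repulsive inside the diameter sigma, exactly zero outside it.
    if r < sigma:
        return eps / alpha * (1.0 - r / sigma) ** alpha
    return 0.0


def harmonic_soft_sphere(sigma):
    # eps = 1 and alpha = 2 give a quadratic (harmonic) repulsion.
    return lambda r: soft_sphere(1.0, 2.0, sigma, r)


e = harmonic_soft_sphere(1.0)
print(e(0.0))  # 0.5    (full overlap)
print(e(0.5))  # 0.125  (half overlap)
print(e(1.5))  # 0.0    (no contact)
```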
Here is a naive pairwise energy function constructor that promotes a
two-particle energy to a whole-system energy by applying it to every
pair of distinct particles. We start with this to have something to
test against.
def pair_energy(
    energy: (Float) -> Float,
    displacement: Displacement dim,
    r: n=>Vec dim
    ) -> Float given (n|Ix, dim|Ix) =
  sum $ for i. sum for j:(..<i).
    energy $ norm $ displacement r[i] r[inject j]
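The triangular loop over j < i visits each unordered pair exactly once. A small Python sketch of the same idea, using the free (non-periodic) displacement and three hand-picked points as assumptions:

```python
import math


def harmonic(r, sigma=1.0):
    return 0.5 * (1.0 - r / sigma) ** 2 if r < sigma else 0.0


def pair_energy(energy, positions):
    total = 0.0
    for i in range(len(positions)):
        for j in range(i):  # j < i: each unordered pair once
            dr = [a - b for a, b in zip(positions[i], positions[j])]
            total += energy(math.sqrt(sum(c * c for c in dr)))
    return total


positions = [[0.0, 0.0], [0.5, 0.0], [2.0, 0.0]]
total = pair_energy(harmonic, positions)
print(total)  # 0.125: only the first two particles overlap
```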
Now we develop the FIRE Descent algorithm.
struct FireDescentState(n|Ix, dim|Ix) =
  R : (n=>Vec dim)
  V : (n=>Vec dim)
  F : (n=>Vec dim)
  dt : Float
  alpha : Float
  n_pos : Int

def fire_descent_init(
    dt:Float,
    alpha:Float,
    energy:Energy(n, dim),
    r:n=>Vec dim
    ) -> FireDescentState n dim given (n|Ix, dim|Ix) =
  force = \rp. -(grad energy) rp
  V = for i:n j:dim. 0.0
  F = force r
  n_pos : Int = 0
  FireDescentState(r, V, F, dt, alpha, n_pos)

def fire_descent_step(
    shift:Shift dim,
    energy:Energy n dim,
    state:FireDescentState n dim
    ) -> FireDescentState n dim given (n|Ix, dim|Ix) =
  dt_start = 0.1
  dt_max = 0.4
  n_min = 5
  f_inc = 1.1
  f_dec = 0.5
  f_alpha = 0.99
  alpha_start = 0.1
  ε = 0.000000001
  force = \r. -(grad energy) r
  dt = state.dt
  alpha = state.alpha
  n_pos = state.n_pos
  R = for i. shift state.R[i] (state.V[i] *. dt + state.F[i] *. pow dt 2)
  F_old = state.F
  F = force R
  V = state.V + dt * 0.5 .* (F_old + F)
  F_norm = sqrt $ sum $ sum for i j. pow F[i, j] 2
  V_norm = sqrt $ sum $ sum for i j. pow V[i, j] 2
  V = V + alpha .* (F *. V_norm / (F_norm + ε) - V)
  FdotV = sum $ sum for i j. F[i, j] * V[i, j]
  (n_pos, dt, alpha) = if FdotV >= 0.0
    then
      dt' = if n_pos >= n_min then (min (dt * f_inc) dt_max) else dt
      alpha' = if n_pos >= n_min then (alpha * f_alpha) else alpha
      (n_pos + 1, dt', alpha')
    else (0, dt * f_dec, alpha_start)
  V = if FdotV >= 0.0 then V else zero
  FireDescentState(R, V, F, dt, alpha, n_pos)
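For readers who want to trace the update by hand, here is a rough Python transcription of the step above for a flat list of coordinates. It is a sketch under assumptions: the energy is a toy quadratic bowl rather than the soft-sphere pair energy, the force is written out by hand instead of coming from grad, and positions are shifted freely (no periodic wrap).

```python
import math


def fire_step(x, v, f, dt, alpha, n_pos, force):
    # Constants copied from fire_descent_step.
    dt_max, n_min = 0.4, 5
    f_inc, f_dec, f_alpha, alpha_start = 1.1, 0.5, 0.99, 0.1
    eps = 1e-9
    # Move the positions, then evaluate the force at the new positions.
    x = [xi + vi * dt + fi * dt ** 2 for xi, vi, fi in zip(x, v, f)]
    f_old, f = f, force(x)
    v = [vi + dt * 0.5 * (fo + fn) for vi, fo, fn in zip(v, f_old, f)]
    # Mix the velocity toward the direction of the force.
    f_norm = math.sqrt(sum(fi * fi for fi in f))
    v_norm = math.sqrt(sum(vi * vi for vi in v))
    v = [vi + alpha * (fi * v_norm / (f_norm + eps) - vi) for vi, fi in zip(v, f)]
    f_dot_v = sum(fi * vi for fi, vi in zip(f, v))
    if f_dot_v >= 0.0:
        # Still heading downhill: after n_min such steps, speed up.
        if n_pos >= n_min:
            dt, alpha = min(dt * f_inc, dt_max), alpha * f_alpha
        n_pos += 1
    else:
        # Heading uphill: stop, shrink the time step, reset the mixing.
        n_pos, dt, alpha = 0, dt * f_dec, alpha_start
        v = [0.0 for _ in v]
    return x, v, f, dt, alpha, n_pos


def energy(x):  # toy quadratic bowl (assumption of this sketch)
    return 0.5 * sum(xi * xi for xi in x)


def force(x):   # minus the gradient of the bowl
    return [-xi for xi in x]


x0 = [1.0, -2.0]
state = (x0, [0.0, 0.0], force(x0), 0.1, 0.1, 0)
energies = [energy(x0)]
for _ in range(100):
    state = fire_step(*state, force)
    energies.append(energy(state[0]))
print(energies[1] < energies[0], energies[-1] < energies[0])
```

The velocity reset in the uphill branch is what keeps the minimizer from oscillating around the minimum.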
Now a tool to draw a two-dimensional system, where each particle is a
disk of given size.
def draw_system(radius, r:n=>Vec TwoDimensions) -> Diagram given (n|Ix) =
  disks = concat_diagrams for i.
    circle radius | move_xy(r[i, (0 @ TwoDimensions)], r[i, (1 @ TwoDimensions)])
  disks
Example
Initialize a system randomly.
L_small = box_size_at_number_density (n_to_i N_small) 1.2 (n_to_i d)
R_init_small = rand_mat N_small d (\k. L_small * rand k) (new_key 0)
The initial state of our random system.
%time
:html render_svg (draw_system 0.5 R_init_small) (Point(0.0, 0.0), Point(L_small, L_small))
Compile time: 236.679 ms
Run time: 26.521 ms
Define the energy function. Note the periodic_displacement, which means
our system will evolve on a torus.
def energy(pos:n=>d=>Float) -> Float given (n|Ix, d|Ix) =
  pair_energy (harmonic_soft_sphere 1.0) (periodic_displacement L_small) pos
Here's the initial energy we compute for our system.
:t energy R_init_small
Float32
energy R_init_small
74.69006
We initialize the FIRE Descent state
state_small = fire_descent_init 0.1 0.1 energy R_init_small
and test one step of minimization. The energy decreases from its
initial value, as expected:
energy $ fire_descent_step(free_shift, energy, state_small).R
71.78407
Now we can test that our code basically works by running 100 steps of
minimization.
%time
(state_small', energies) = scan state_small \i:(Fin 100) s.
  s' = fire_descent_step (periodic_shift L_small) energy s
  (s', energy $ s'.R)
Compile time: 590.614 ms
Run time: 817.194 ms
Here's how the energy decreases over time.
%time
:html show_plot $ y_plot energies
Compile time: 389.418 ms
Run time: 5.340 ms
Here's what the system looks like after minimization.
%time
:html render_svg (draw_system 0.5 (state_small'.R)) (Point(0.0, 0.0), Point(L_small, L_small))
Compile time: 235.188 ms
Run time: 25.306 ms
The pair_energy function above computes the influence of every atom on
every other atom, regardless of how far apart they are.
To simulate more efficiently, we'd like to approximate the pairwise
energy with one that only includes contributions from atoms that are
close enough to each other that they cannot be neglected.
This is a two-step operation:
1. Break the simulation volume into a grid of cells, and do a linear
   pass over the atoms to group them by the cell each one is in.
2. Traverse every pair of adjacent cells and compute energy terms for
   every pair of atoms in those cells only.
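The two steps can be sketched in a few lines of Python. This is an illustration only: it uses a dictionary of lists rather than fixed-size buckets, it is specialized to two dimensions, and the names build_cells and candidate_pairs, the 4 x 4 grid and the five atom positions are all made up for the example.

```python
from collections import defaultdict
from itertools import product


def build_cells(atoms, cell_size):
    # Step 1: one linear pass that groups atom indices by grid cell.
    cells = defaultdict(list)
    for ix, (x, y) in enumerate(atoms):
        cells[(int(x // cell_size), int(y // cell_size))].append(ix)
    return cells


def candidate_pairs(cells, grid_size):
    # Step 2: pair every atom with the atoms in its own cell and the
    # 8 surrounding cells, wrapping around the edges of the grid.
    pairs = []
    for (cx, cy), members in cells.items():
        for dx, dy in product((-1, 0, 1), repeat=2):
            nb = ((cx + dx) % grid_size, (cy + dy) % grid_size)
            for a in members:
                for b in cells.get(nb, []):
                    pairs.append((a, b))
    return pairs


atoms = [(0.2, 0.2), (0.8, 0.9), (1.5, 0.5), (3.9, 0.1), (2.5, 2.5)]
cells = build_cells(atoms, cell_size=1.0)
pairs = candidate_pairs(cells, grid_size=4)
print(len(pairs))       # 15, versus 25 ordered pairs overall
print((0, 3) in pairs)  # True: adjacent across the periodic edge
print((0, 4) in pairs)  # False: two cells apart
```

The result contains ordered pairs, including each atom paired with itself; a later filter has to decide which of them are real neighbors.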
We start with an abstraction of an incrementally growable list. To
get O(1) insertion at the end, we (currently) have to give an upper
bound for the list's size.
struct BoundedList(n|Ix, a|Data) =
  lim : n
  buf : n => a

def unsafe_next_index(ix:n) -> n given (n|Ix) =
  unsafe_from_ordinal $ ordinal ix + 1

def empty_bounded_list(dummy_val:a) -> BoundedList n a given (a|Data, n|Ix) =
  BoundedList(unsafe_from_ordinal 0, for _. dummy_val)

def bd_push(ref:Ref h (BoundedList n a), x:a) -> {State h} () given (h, a|Data, n|Ix) =
  sz = get ref.lim
  if ordinal sz < size n
    then
      ref.buf!sz := x
      ref.lim := unsafe_next_index sz
    else
      todo

def as_list(lst:BoundedList n a) -> List a given (a|Data, n|Ix) =
  n_result = ordinal lst.lim
  AsList _ $ for i:(Fin n_result). lst.buf[unsafe_from_ordinal $ ordinal i]
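The same data structure is easy to picture as a preallocated buffer plus a fill counter. The Python class below is a sketch of that idea; raising OverflowError on a full list is an assumption of this example (the Dex code above leaves that branch as todo).

```python
class BoundedList:
    """Append-only list with a fixed capacity chosen up front."""

    def __init__(self, capacity, dummy):
        self.buf = [dummy] * capacity  # preallocated storage
        self.lim = 0                   # number of slots in use

    def push(self, x):
        # O(1): write into the next free slot and bump the counter.
        if self.lim >= len(self.buf):
            raise OverflowError("BoundedList is full")
        self.buf[self.lim] = x
        self.lim += 1

    def as_list(self):
        # Only the first `lim` slots hold real data.
        return self.buf[:self.lim]


lst = BoundedList(capacity=3, dummy=-1)
lst.push(33)
lst.push(237)
print(lst.as_list())  # [33, 237]
```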
We define a single index for the whole grid.
def GridIx(dim|Ix, grid_size) -> Type = dim => (Fin grid_size)
A cell table is now just a table that holds, for each cell in the grid,
a BoundedList of the (indices of the) atoms that appear in that cell.
def CellTable(dim|Ix, grid_size, bucket_size, atom_ix|Data) -> Type =
  GridIx dim grid_size => BoundedList (Fin bucket_size) atom_ix

def target_cell(cell_size:Float, atom:Vec dim)
    -> GridIx dim grid_size given (dim|Ix, grid_size) =
  for dim. from_ordinal $ f_to_n $ atom[dim] / cell_size

def traverse_cells(
    grid_size, cell_size,
    atoms: atom_ix => Vec dim,
    action: (atom_ix, GridIx(dim, grid_size)) -> {|eff} ()
    ) -> {|eff} () given (dim|Ix, atom_ix|Ix, eff) =
  for_ ix.
    cell = target_cell cell_size atoms[ix]
    action ix cell
Here is the actual cell table computation:
def cell_table(
    grid_size, bucket_size, cell_size,
    atoms: atom_ix => Vec dim
    ) -> CellTable(dim, grid_size, bucket_size, atom_ix)
    given (dim|Ix, atom_ix|Ix) =
  yield_state (for _. empty_bounded_list $ unsafe_from_ordinal 0) \ref.
    traverse_cells grid_size cell_size atoms \ix cell.
      bd_push ref!cell ix

def cell_table_bucket_size(
    grid_size, cell_size, atoms: atom_ix => Vec dim
    ) -> Nat given (dim|Ix, atom_ix|Ix) =
  bucket_sizes = yield_state (for _:(GridIx dim grid_size). 0) \ref.
    traverse_cells grid_size cell_size atoms \ix cell.
      ref!cell := get ref!cell + 1
  maximum bucket_sizes
Let's test it out on our "small" system:
def grid_params(L, desired_cell_size) =
  cells_per_side = floor (L / desired_cell_size)
  cell_size = L / cells_per_side
  grid_size = unsafe_i_to_n $ f_to_i cells_per_side
  (cell_size, grid_size)

(cell_size, grid_size) = grid_params L_small desired_cell_size
%time
tbl = cell_table grid_size bucket_size cell_size $ state_small'.R
Compile time: 107.760 ms
Run time: 31.300 us
We have a table of cells with atoms in them
:t tbl
(((Fin 2) => Fin 20) => BoundedList (Fin 10) (Fin 500))
And here are the atoms in the 0th cell.
as_list tbl[unsafe_from_ordinal 0]
(AsList 2 [33, 237])
Now let's compute pairs of neighbors from our cell list.
We'll specialize to two dimensions for now, but broadening to
more is not difficult.
cell_neighbors_2d = [[-1, -1], [-1, 0], [-1, 1], [0, -1],
                     [0, 0], [0, 1], [1, -1], [1, 0], [1, 1]]
:t cell_neighbors_2d
((Fin 9) => (Fin 2) => Int32)
def torus_offset(ix:(Fin n), offset:Int) -> (Fin n) given (n) =
  unsafe_from_ordinal $ unsafe_i_to_n $
    mod (n_to_i (ordinal ix) + offset) (n_to_i n)

def traverse_pairs(
    tbl: CellTable(TwoDimensions, grid_size, bucket_size, atom_ix),
    atoms: atom_ix => Vec TwoDimensions,
    action: (atom_ix, atom_ix) -> {|eff} ()
    ) -> {|eff} () given (grid_size, bucket_size, atom_ix|Ix, eff) =
  for_ cell_ix:(GridIx TwoDimensions grid_size).
    for_ nb.
      displacement = cell_neighbors_2d[nb]
      neighbor_ix = for dim. torus_offset cell_ix[dim] displacement[dim]
      AsList(sz_atoms_1, atoms_1) = as_list tbl[cell_ix]
      for_ atom1_ix:(Fin sz_atoms_1).
        atom1 = atoms_1[atom1_ix]
        AsList(sz_atoms_2, atoms_2) = as_list tbl[neighbor_ix]
        for_ atom2_ix:(Fin sz_atoms_2).
          atom2 = atoms_2[atom2_ix]
          action atom1 atom2
The neighbor list computation. The point of the exercise is that
this is not O(#atoms^2), but rather O(#cells) * 9 * O(#atoms per
cell^2), because it only considers atoms from adjacent cells as
potential neighbors.
def neighbor_list(
    neighbor_list_size,
    tbl: CellTable(TwoDimensions, grid_size, bucket_size, atom_ix),
    is_neighbor: (atom_ix, atom_ix) -> Bool,
    atoms: atom_ix => Vec TwoDimensions
    ) -> BoundedList(Fin neighbor_list_size, (atom_ix, atom_ix))
    given (grid_size, bucket_size, atom_ix|Ix) =
  yield_state (empty_bounded_list((from_ordinal 0, from_ordinal 0))) \ref.
    traverse_pairs(tbl, atoms) \atom1 atom2.
      if is_neighbor(atom1, atom2)
        then bd_push(ref, (atom1, atom2))
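To put rough numbers on that cost estimate, the Python sketch below compares the two counts for the small test system, taking 500 atoms and a 20 x 20 grid from the type of tbl shown earlier. It assumes the atoms are spread evenly over the cells, which is only approximately true for a random configuration.

```python
def all_pairs_cost(n_atoms):
    # Every atom against every atom.
    return n_atoms * n_atoms


def cell_list_cost(n_atoms, grid_size, dim=2):
    # cells * (3**dim adjacent cells) * (atoms per cell)**2
    n_cells = grid_size ** dim
    per_cell = n_atoms / n_cells
    return n_cells * 3 ** dim * per_cell ** 2


print(all_pairs_cost(500))      # 250000
print(cell_list_cost(500, 20))  # 5625.0
```

With the density held fixed, the number of atoms per cell stays constant as the system grows, so the second count grows linearly with the number of atoms.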
def count_neighbor_list_size(
    tbl: CellTable(TwoDimensions, grid_size, bucket_size, atom_ix),
    is_neighbor: (atom_ix, atom_ix) -> Bool,
    atoms: atom_ix => Vec TwoDimensions
    ) -> Nat given (grid_size, bucket_size, atom_ix|Ix) =
  yield_accum (AddMonoid Nat) \ref.
    traverse_pairs tbl atoms (\atom1 atom2.
      if is_neighbor atom1 atom2
        then ref += 1)

def periodic_near(
    desired_cell_size,
    L,
    atoms: atom_ix => Vec TwoDimensions
    ) -> (atom_ix, atom_ix) -> Bool given (atom_ix|Ix) = \atom1 atom2.
  dist = norm $ periodic_displacement(L)(atoms[atom1], atoms[atom2])
  dist < desired_cell_size
Let's test this one out on our "small" system.
%time
res = (neighbor_list 4000 tbl
  (periodic_near 1.0 L_small $ state_small'.R) $ state_small'.R)
Compile time: 120.297 ms
Run time: 1.318 ms
In that configuration, we find this many pairs of neighbors:
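AsList(k, _) = as_list res
Now that we have the concept of neighbor lists, we can define a
variant of pair_energy that only considers atoms that the neighbor
list says are close. The list holds each pair in both orders (and each
atom paired with itself), so a pair is counted only when its first
index is the smaller one.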
def pair_energy_nl(
    energy: (Float) -> Float,
    displacement: Displacement dim,
    r: n=>Vec dim,
    neighbors: List (n, n)
    ) -> Float given (dim|Ix, n|Ix) =
  AsList(k, nbrs) = neighbors
  sum for i.
    (a1ix, a2ix) = nbrs[i]
    case (ordinal a1ix) < (ordinal a2ix) of
      True -> energy $ norm $ displacement r[a1ix] r[a2ix]
      False -> 0.0

def energy_nl(
    L,
    neighbors: List (n, n)
    ) -> (n=>d=>Float) -> Float given (d|Ix, n|Ix) = \pos.
  pair_energy_nl (harmonic_soft_sphere 1.0) (periodic_displacement L) pos neighbors
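The index comparison in pair_energy_nl is what prevents double counting. The Python sketch below shows the effect on a hand-written neighbor list that contains a pair in both orders plus self-pairs; the positions and the list are made up for the example, and the free displacement is used instead of the periodic one.

```python
import math


def harmonic(r, sigma=1.0):
    return 0.5 * (1.0 - r / sigma) ** 2 if r < sigma else 0.0


def pair_energy_nl(energy, positions, neighbors):
    total = 0.0
    for a, b in neighbors:
        if a < b:  # count (a, b) but skip (b, a) and (a, a)
            dr = [p - q for p, q in zip(positions[a], positions[b])]
            total += energy(math.sqrt(sum(c * c for c in dr)))
    return total


positions = [[0.0, 0.0], [0.5, 0.0], [2.0, 0.0]]
neighbors = [(0, 0), (0, 1), (1, 0), (1, 1), (2, 2)]
total = pair_energy_nl(harmonic, positions, neighbors)
print(total)  # 0.125, the same as summing over all distinct pairs
```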
And here's the check that it computes near-identical results to our
original, fully pairwise energy function.
energy_nl(L_small, as_list res)(state_small'.R)
1.230771
energy(state_small'.R)
1.230771
def just_neighbor_list(
    desired_cell_size,
    L,
    atoms: n => (Fin 2) => Float
    ) -> (List (n, n)) given (n|Ix) =
  (cell_size, grid_size) = grid_params L desired_cell_size
  bucket_size = cell_table_bucket_size grid_size cell_size atoms
  tbl = cell_table grid_size bucket_size cell_size atoms
  is_neighbor = periodic_near desired_cell_size L atoms
  neighbor_list_sz = count_neighbor_list_size tbl is_neighbor atoms
  res = neighbor_list neighbor_list_sz tbl is_neighbor atoms
  as_list res
The neighbor list energy function works as an argument to our FIRE
Descent algorithm (provided the neighbor list uses the same atoms).
state_nl =
  energy_func = (energy_nl L_small $ just_neighbor_list 1.0 L_small R_init_small)
  fire_descent_init 0.1 0.1 energy_func R_init_small
energy R_init_small
74.69006
energy $ fire_descent_step(free_shift, energy, state_small).R
71.78407
def fast_any(f:(n) -> {|eff} Bool) -> {|eff} Bool given (n|Ix, eff) =
  iter \ct.
    if ct >= size n
      then Done False
      else if f (ct @ n) then Done True else Continue
And now that this basically works, we can package the whole thing up
as a simulation. We have another trick here: we compute the neighbor
list with an extra "halo", treating atoms as neighbors if they are
within distance 1 + halo of each other, rather than just within the
interaction range 1. This way, we only have to recompute the neighbor
list when some atom has moved more than halo/2 from where it was when
the neighbor list was computed. Until then the list remains a sound
approximation, because two atoms that have each moved at most halo/2
cannot have approached each other by more than halo.
def simulate(
    displacement: Displacement TwoDimensions,
    halo_size,
    L,
    time|Ix,
    state: FireDescentState(atom_ix, TwoDimensions)
    ) -> {IO} (FireDescentState(atom_ix, TwoDimensions), time=>Float) given (atom_ix|Ix) =
  with_state state.R \saved_atoms_ref.
    nbrs = just_neighbor_list (1.0 + halo_size) L $ state.R
    AsList(k, _) = nbrs
    print $ show k <> " initial neighbor list size"
    with_state nbrs \saved_neighbors_ref.
      swap $ run_state state \s_ref. for i.
        s = get s_ref
        new_atoms = s.R
        rebuild = fast_any \i.
          2 * norm (displacement (get saved_atoms_ref!i) new_atoms[i]) > halo_size
        if rebuild then
          saved_atoms_ref := new_atoms
          nbrs = just_neighbor_list (1.0 + halo_size) L new_atoms
          saved_neighbors_ref := nbrs
          AsList(k, _) = nbrs
          print $ show k <> " new neighbor list size"
        nbrs = get saved_neighbors_ref
        s' = fire_descent_step (periodic_shift L) (energy_nl L nbrs) s
        s_ref := s'
        energy_nl(L, nbrs)(s'.R)
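The rebuild test inside simulate can be isolated into a few lines. The Python sketch below shows the criterion on its own, with made-up positions and the free (non-periodic) distance as simplifying assumptions; the name needs_rebuild is hypothetical.

```python
import math


def needs_rebuild(saved, current, halo):
    # Rebuild once any atom has moved more than halo / 2 since the
    # neighbor list was last computed. any() stops at the first hit,
    # which is the short-circuiting that fast_any provides.
    return any(2.0 * math.dist(s, c) > halo for s, c in zip(saved, current))


saved = [(0.0, 0.0), (1.0, 1.0)]
small_move = [(0.1, 0.0), (1.0, 1.0)]
large_move = [(0.3, 0.0), (1.0, 1.0)]
print(needs_rebuild(saved, small_move, halo=0.5))  # False
print(needs_rebuild(saved, large_move, halo=0.5))  # True
```

A larger halo means fewer rebuilds but a longer neighbor list per energy evaluation, so the halo size trades one cost against the other.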
Let's check that it works on our test system from before
%time
(state_nl' , energies_nl ) =
  unsafe_io \. simulate (periodic_displacement L_small) 0.5 L_small (Fin 100) state_small
4568 initial neighbor list size
4614 new neighbor list size
4564 new neighbor list size
4376 new neighbor list size
4346 new neighbor list size
4266 new neighbor list size
4178 new neighbor list size
4140 new neighbor list size
4100 new neighbor list size
4028 new neighbor list size
4006 new neighbor list size
3922 new neighbor list size
3868 new neighbor list size
3810 new neighbor list size
3762 new neighbor list size
3720 new neighbor list size
3656 new neighbor list size
3640 new neighbor list size
3602 new neighbor list size
3572 new neighbor list size
3552 new neighbor list size
Compile time: 1.758 s
Run time: 55.993 ms
%time
:html show_plot $ y_plot energies_nl
Compile time: 406.720 ms
Run time: 5.293 ms
%time
:html render_svg (draw_system 0.5 state_nl'.R) (Point(0.0, 0.0), Point(L_small, L_small))
Compile time: 233.455 ms
Run time: 24.784 ms
But of course the point of the exercise is that this now scales up to
larger systems because it avoids the quadratic energy computation.
N_large = if not (dex_test_mode()) then 50000 else 500
L_large = box_size_at_number_density (n_to_i N_large) 1.2 (n_to_i d)
R_init_large = rand_mat N_large d (\k. L_large * rand k) (new_key 0)
Initial state (we render the atoms smaller now so they don't over-plot too badly).
%time
:html render_svg (draw_system 0.2 R_init_large) (Point(0.0, 0.0), Point(L_large, L_large))
Compile time: 648.791 ms
Run time: 2.525 s
state_large =
  energy_func = (energy_nl L_large $ just_neighbor_list 1.0 L_large R_init_large)
  fire_descent_init 0.1 0.1 energy_func R_init_large
%time
(state_large_nl', energies_large_nl) =
  unsafe_io \. simulate (periodic_displacement L_large) 0.5 L_large (Fin 100) state_large
474778 initial neighbor list size
474090 new neighbor list size
471628 new neighbor list size
465540 new neighbor list size
453724 new neighbor list size
442298 new neighbor list size
434298 new neighbor list size
428432 new neighbor list size
421932 new neighbor list size
418256 new neighbor list size
414656 new neighbor list size
411276 new neighbor list size
404290 new neighbor list size
398562 new neighbor list size
393958 new neighbor list size
390030 new neighbor list size
386782 new neighbor list size
382868 new neighbor list size
379510 new neighbor list size
376428 new neighbor list size
374024 new neighbor list size
371272 new neighbor list size
368946 new neighbor list size
366958 new neighbor list size
365340 new neighbor list size
363842 new neighbor list size
362166 new neighbor list size
360926 new neighbor list size
360118 new neighbor list size
359558 new neighbor list size
358906 new neighbor list size
358038 new neighbor list size
357402 new neighbor list size
357078 new neighbor list size
Compile time: 1.581 s
Run time: 11.275 s
%time
:html show_plot $ y_plot energies_large_nl
Compile time: 765.950 ms
Run time: 4.562 ms
System state after minimization.
%time
:html render_svg (draw_system 0.2 state_large_nl'.R) (Point(0.0, 0.0), Point(L_large, L_large))
Compile time: 429.204 ms
Run time: 2.501 s